Training Loop

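Critic Update: For every generator update, the critic is updated crit_cycles times. A critic trained closer to optimality gives a more reliable estimate of the Wasserstein distance between real and generated images. That estimate in turn gives the generator more useful gradients, which helps stabilize GAN training.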

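  • The critic loss is the mean critic score on fake images minus the mean score on real images, plus the gradient penalty term. The penalty pushes the norm of the critic's gradient toward 1 on random interpolations between real and fake images.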
  • The critic is trained by minimizing this loss.
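The critic loss described above can be sketched as a self-contained function. The notebook's get_gp is not shown in this section, so gradient_penalty and critic_loss below are hypothetical stand-ins that follow the standard WGAN-GP recipe. The weight lambda_gp = 10 is the value suggested in the WGAN-GP paper and is an assumption here. There is one difference from the loop: the loop builds alpha outside the penalty function with requires_grad=True, while this sketch samples alpha inside and marks the interpolated images themselves as requiring gradients.

```python
import torch
import torch.nn as nn


def gradient_penalty(crit, real, fake, lambda_gp=10.0):
    """Hypothetical WGAN-GP penalty: lambda_gp * E[(||grad C(x_hat)||_2 - 1)^2]."""
    # One interpolation weight per sample, broadcast over the remaining dimensions
    # (B x 1 x 1 x 1 for image batches, B x 1 for flat vectors).
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    # Interpolates are built from detached inputs, then marked as requiring grad
    # so torch.autograd.grad can differentiate the critic with respect to them.
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = crit(x_hat)
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated into crit
    )
    grad_norm = grads.flatten(1).norm(2, dim=1)  # per-sample L2 norm of the input gradient
    return lambda_gp * ((grad_norm - 1) ** 2).mean()


def critic_loss(crit, real, fake, lambda_gp=10.0):
    """Mean fake score minus mean real score, plus the gradient penalty."""
    fake = fake.detach()  # the critic step must not update the generator
    return crit(fake).mean() - crit(real).mean() + gradient_penalty(crit, real, fake, lambda_gp)
```

You can sanity-check the sketch with a linear critic C(x) = w·x + b. Its input gradient is w for every sample, so the penalty reduces to lambda_gp * (||w||_2 - 1)^2. With w set to four ones, ||w||_2 = 2 and the penalty is exactly 10.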

Generator Update: After the critic has been updated, the generator is updated. The generator loss is the negative mean of the critic’s predictions on the fake images.

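  • Minimizing this loss raises the critic's scores on generated images, so the generator learns to produce images that the critic rates like real ones. A WGAN critic outputs an unbounded score rather than a probability, so it does not literally classify images as real or fake.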
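The generator update can be sketched the same way. generator_step is a hypothetical helper. torch.randn stands in for the notebook's gen_noise, which is assumed here to draw standard Gaussian noise. gen_opt is assumed to hold only the generator's parameters, as in the loop.

```python
import torch
import torch.nn as nn


def generator_step(gen, crit, gen_opt, batch_size, z_dim, device="cpu"):
    """Hypothetical single WGAN generator update: minimise -E[C(G(z))]."""
    gen_opt.zero_grad()
    # Assumption: Gaussian noise as a stand-in for the notebook's gen_noise helper.
    noise = torch.randn(batch_size, z_dim, device=device)
    fake = gen(noise)               # no detach: the loss must reach the generator
    gen_loss = -crit(fake).mean()   # lower loss <=> higher critic score on the fakes
    gen_loss.backward()             # also writes .grad into crit's parameters
    gen_opt.step()                  # only the generator's parameters are updated
    return gen_loss.item()
```

Because gen_loss.backward() also writes gradients into the critic's parameters, the loop relies on crit_opt.zero_grad() at the start of each critic cycle to discard them. Only gen_opt.step() is called here, so the critic's weights stay unchanged.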

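Visualization and Checkpoints: Every save_step steps, the loop saves a checkpoint, always under the name 'latest'. Every show_step steps, it displays a batch of generated and real images, prints the epoch, step and losses, and plots the full generator and critic loss histories.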
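Both periodic actions use the same trigger, cur_step % step == 0 and cur_step > 0. The sketch below isolates that condition. It also adds a hypothetical naming scheme that gives every save its own name, as the loop's own comment recommends. save_step = 35 is inferred from the "Saving checkpoint: 10045 35" log line, since the loop prints cur_step followed by save_step.

```python
def should_fire(cur_step: int, every: int) -> bool:
    """Same trigger as the loop: cur_step % every == 0 and cur_step > 0."""
    return cur_step > 0 and cur_step % every == 0


def checkpoint_name(epoch: int, cur_step: int) -> str:
    """Hypothetical naming scheme that gives every save its own name."""
    return f"epoch{epoch:03d}_step{cur_step:07d}"


save_step = 35  # value implied by the "Saving checkpoint: 10045 35" log line
print([s for s in range(10040, 10120) if should_fire(s, save_step)])  # [10045, 10080, 10115]
print(should_fire(0, save_step))                                      # False
print(checkpoint_name(0, 10045))                                      # epoch000_step0010045
```

In the loop, save_checkpoint('latest') could then become save_checkpoint(checkpoint_name(epoch, cur_step)). This assumes save_checkpoint uses its argument as (part of) the file name.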

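# Training loop
# Uses objects defined in earlier cells: gen, crit, gen_opt, crit_opt, dataloader,
# gen_noise, get_gp, show, save_checkpoint, n_epochs, crit_cycles, z_dim, device,
# cur_step, save_step, show_step, gen_losses and crit_losses.
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

for epoch in range(n_epochs):
  for real, _ in tqdm(dataloader):
    cur_bs = len(real) # 128 for full batches; the last batch of an epoch can be smaller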
    real = real.to(device)


    ## Critic
    mean_crit_loss = 0
    for _ in range(crit_cycles):

      # reset the critic's gradients (this also discards those left by the previous generator step)
      crit_opt.zero_grad()

      noise = gen_noise(cur_bs, z_dim)
      fake = gen(noise)
      # detach so the critic loss does not backpropagate into the generator
      crit_fake_pred = crit(fake.detach())
      crit_real_pred = crit(real)

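      # one random interpolation weight per image, broadcast over channels, height and width;
      # requires_grad=True presumably lets get_gp differentiate the critic w.r.t. the interpolated images
      alpha = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True) # cur_bs x 1 x 1 x 1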

      # calculating gradient penalty
      gp = get_gp(real, fake.detach(), crit, alpha)

      # calculating loss
      crit_loss = crit_fake_pred.mean() - crit_real_pred.mean() + gp

      # .item() extracts the Python number from the 0-dim loss tensor
      mean_crit_loss += crit_loss.item() / crit_cycles

      # backpropagate and update the critic; retain_graph is not needed because
      # every critic cycle builds a fresh graph from new noise
      crit_loss.backward()
      crit_opt.step()

    # record the critic loss averaged over the crit_cycles updates
    crit_losses += [mean_crit_loss]


    ## Generator

    # reset the generator's gradients
    gen_opt.zero_grad()

    # sample a fresh noise batch of shape cur_bs x z_dim (128 x 200 here)
    noise = gen_noise(cur_bs, z_dim)

    # passing noise through generator
    fake = gen(noise)

    # score the fakes with the critic; no detach, so gradients reach the generator
    crit_fake_pred = crit(fake)

    # generator loss: negative mean critic score on the fakes
    gen_loss = -crit_fake_pred.mean()

    # backpropagation
    gen_loss.backward()

    # updating the parameters of the generator
    gen_opt.step()

    gen_losses += [gen_loss.item()]


    if cur_step % save_step == 0 and cur_step > 0:
      print('Saving checkpoint:', cur_step, save_step)
      # consider a distinct name per save (e.g. including the epoch or step) so older checkpoints are not overwritten
      save_checkpoint('latest')

    if (cur_step % show_step == 0 and cur_step > 0):
      show(fake, wandbactivation=1, name='fake')
      show(real, wandbactivation=1, name='real')

      # average losses over the last show_step steps
      gen_mean = sum(gen_losses[-show_step:]) / show_step
      crit_mean = sum(crit_losses[-show_step:]) / show_step
      print(f'Epoch: {epoch}, step: {cur_step}, Generator loss: {gen_mean}, Critic loss: {crit_mean}')

      plt.plot(range(len(gen_losses)),
               torch.Tensor(gen_losses),
               label='Generator loss')

      plt.plot(range(len(crit_losses)),
               torch.Tensor(crit_losses),
               label='Critic loss')

      plt.ylim(-200,200)
      plt.legend()
      plt.show()
    cur_step += 1
  0%|          | 0/391 [00:00<?, ?it/s]
Saving checkpoint: 10045 35
Saved checkpoint
Epoch: 0, step: 10045, Generator loss: 6.669975280761719, Critic loss: -3.8914854526519775
Saving checkpoint: 10080 35
Saved checkpoint
Epoch: 0, step: 10080, Generator loss: 10.462503433227539, Critic loss: -4.360108852386475
Saving checkpoint: 10115 35
Saved checkpoint
Epoch: 0, step: 10115, Generator loss: 4.7965593338012695, Critic loss: -3.7804861068725586
Saving checkpoint: 10150 35
Saved checkpoint
Epoch: 0, step: 10150, Generator loss: 4.233477592468262, Critic loss: -5.056338787078857
Saving checkpoint: 10185 35
Saved checkpoint
Epoch: 0, step: 10185, Generator loss: 5.628297328948975, Critic loss: -4.169041156768799
Saving checkpoint: 10220 35
Saved checkpoint
Epoch: 0, step: 10220, Generator loss: 12.158222198486328, Critic loss: -4.853008270263672
KeyboardInterrupt: training was stopped by interrupting the kernel (the interrupt arrived at real = real.to(device)).

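With 10,000 images and a batch size of 128, one epoch takes ceil(10000 / 128) = ceil(78.125) = 79 steps: 78 full batches plus a final batch of 16 images. However, the progress bar in the run above shows 391 batches per epoch, so the dataset loaded there was larger (roughly 50,000 images).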
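The steps-per-epoch arithmetic can be checked with a small helper. steps_per_epoch is hypothetical; in practice len(dataloader) returns the count directly. drop_last is the real DataLoader option that discards an incomplete final batch and defaults to False. The 50,000-image case is only an example size consistent with the 391 batches in the progress bar.

```python
import math


def steps_per_epoch(n_images: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style dataset, as reported by len(dataloader)."""
    if drop_last:
        return n_images // batch_size        # the incomplete last batch is dropped
    return math.ceil(n_images / batch_size)  # the smaller last batch is kept


print(steps_per_epoch(10_000, 128))                  # 79
print(steps_per_epoch(10_000, 128, drop_last=True))  # 78
print(10_000 - 78 * 128)                             # 16 images in the last batch
print(steps_per_epoch(50_000, 128))                  # 391
```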

Summary of the GAN Workflow:

  • Noise Generation: Random noise is generated and passed into the generator.
  • Image Generation: The generator transforms the noise into a synthetic image.
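  • Critic Evaluation: The critic scores both real and fake images, aiming to rate real images higher than fakes. Unlike a standard GAN discriminator, it outputs an unbounded score rather than a probability.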
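The three steps above can be traced with toy stand-in modules to show the tensor shapes involved. toy_gen and toy_crit are hypothetical single-layer networks, not the notebook's gen and crit. The 128 x 200 noise shape comes from the training loop's comment; the 3 x 32 x 32 image size is an assumption.

```python
import torch
import torch.nn as nn

# Hypothetical toy modules, NOT the notebook's gen/crit: they only illustrate shapes.
z_dim, img_shape = 200, (3, 32, 32)   # image size is an assumption
n_pixels = 3 * 32 * 32

toy_gen = nn.Sequential(
    nn.Linear(z_dim, n_pixels),
    nn.Tanh(),                        # pixel values in [-1, 1]
    nn.Unflatten(1, img_shape),       # (B, n_pixels) -> (B, 3, 32, 32)
)
toy_crit = nn.Sequential(
    nn.Flatten(),                     # (B, 3, 32, 32) -> (B, n_pixels)
    nn.Linear(n_pixels, 1),           # one unbounded score per image, no sigmoid
)

noise = torch.randn(128, z_dim)       # 1. noise generation
fake = toy_gen(noise)                 # 2. image generation
scores = toy_crit(fake)               # 3. critic evaluation
print(noise.shape, fake.shape, scores.shape)
# torch.Size([128, 200]) torch.Size([128, 3, 32, 32]) torch.Size([128, 1])
```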

Loss Calculation:

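  • The generator minimizes the negative mean critic score on its fakes, so it tries to make the critic score generated images as highly as real ones.
  • The critic maximizes the gap between its mean scores on real and fake images. Meanwhile, the gradient penalty keeps the norm of its gradient close to 1 on points between real and fake images, a soft constraint that keeps the critic smooth (approximately 1-Lipschitz).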
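Written out, the two objectives implemented by the loop are given below. Term by term:

  • crit_fake_pred.mean() is the first expectation in L_C.
  • crit_real_pred.mean() is the second expectation in L_C.
  • gp is the third term of L_C.
  • gen_loss = -crit_fake_pred.mean() is L_G.

The weight λ is not visible in this section; it is presumably applied inside get_gp. The WGAN-GP paper uses λ = 10.

```latex
$$
L_C = \mathbb{E}_{z}\big[C(G(z))\big] - \mathbb{E}_{x \sim P_r}\big[C(x)\big]
      + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} C(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \alpha x + (1 - \alpha)\,G(z),\quad \alpha \sim U[0, 1]
$$

$$
L_G = -\,\mathbb{E}_{z}\big[C(G(z))\big]
$$
```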

Training Loop: The loop alternates crit_cycles critic updates with one generator update, so the quality of the generated images should improve over time. The code follows the WGAN-GP recipe: a Wasserstein critic loss with a gradient penalty, and several critic steps per generator step. This recipe is generally more stable to train than the original GAN loss, although final image quality still depends on the architecture, the data, and how long training runs.